Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

freeze_backbone and freeze_decoder in Trainers #1290

Merged
merged 10 commits into from
Apr 26, 2023

Conversation

isaaccorley
Copy link
Collaborator

@isaaccorley isaaccorley commented Apr 25, 2023

This PR adds a freeze_backbone arg to all trainers and a freeze_decoder arg to the SemanticSegmentationTask and PixelwiseRegressionTask. This is particularly useful if wanting to fine-tune (aka linear probe) only the classifier head of a pretrained model or only the decoder or segmentation head in an encoder/decoder architecture.

@isaaccorley isaaccorley self-assigned this Apr 25, 2023
@isaaccorley isaaccorley changed the title freeze_backbone in ClassificationTask freeze_backbone in ClassificationTask and RegressionTask Apr 25, 2023
@github-actions github-actions bot added testing Continuous integration testing trainers PyTorch Lightning trainers labels Apr 25, 2023
@adamjstewart adamjstewart added this to the 0.5.0 milestone Apr 25, 2023
@isaaccorley isaaccorley requested review from adamjstewart, calebrob6 and nilsleh and removed request for adamjstewart and calebrob6 April 25, 2023 21:37
@adamjstewart
Copy link
Collaborator

Can you do other trainers too? For this paper, all of our tasks will be semantic segmentation, but might as well do all trainers if it's easy.

@isaaccorley isaaccorley force-pushed the trainers/linear-probe branch from be4cd0d to 4fa65fb Compare April 25, 2023 23:08
@isaaccorley isaaccorley changed the title freeze_backbone in ClassificationTask and RegressionTask freeze_backbone and freeze_decoder in Trainers Apr 25, 2023
Copy link
Member

@calebrob6 calebrob6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sick

@isaaccorley isaaccorley merged commit 698d2b5 into microsoft:main Apr 26, 2023
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the parameter names. Also not sure if freezing the ResNet encoder in a U-Net counts as linear probing or not. Same with classification/regression. Might be easier just to freeze the entire model and add a single layer MLP to the end.

torchgeo/trainers/classification.py Show resolved Hide resolved
torchgeo/trainers/classification.py Show resolved Hide resolved
torchgeo/trainers/regression.py Show resolved Hide resolved
@@ -110,7 +124,9 @@ class and used with 'ce' loss
*encoder_weights* to *weights*.

.. versionadded: 0.5
The *class_weights* parameter.
The *class_weights*, *freeze_backbone*,
and *freeze_decoder* parameters.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not documented.

torchgeo/trainers/segmentation.py Show resolved Hide resolved
torchgeo/trainers/classification.py Show resolved Hide resolved
@isaaccorley
Copy link
Collaborator Author

Can you reference some papers that freeze the entire UNet and add fine-tune an MLP on top? I have never seen this before. Also it would make more sense to fine tune a conv layer not a MLP.

Could you also suggest better parameter names?

@isaaccorley isaaccorley deleted the trainers/linear-probe branch April 26, 2023 17:26
@adamjstewart
Copy link
Collaborator

I'm only talking about classification/regression. No idea what the usual practice is for linear probing of semantic segmentation.

Will think about parameter names, don't have a better name off the top of my head. I still kind of like linear_probing because it makes it clear why you would want to use the parameter.

@isaaccorley
Copy link
Collaborator Author

Disagree about the naming. Freeze_backbone is the action we are performing. Linear probing is just the downstream task a user may choose to perform with the classifier with a frozen backbone. There may be other potential reasons for freezing the backbone other than linear probing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants